{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "# Variational AutoEncoder\n",
    "\n",
    "**Author:** [fchollet](https://twitter.com/fchollet)<br>\n",
    "**Date created:** 2020/05/03<br>\n",
    "**Last modified:** 2020/05/03<br>\n",
    "**Description:** Convolutional Variational AutoEncoder (VAE) trained on MNIST digits."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Setup\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "from tensorflow import keras\n",
    "from tensorflow.keras import layers\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Create a sampling layer\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "\n",
    "class Sampling(layers.Layer):\n",
    "    \"\"\"Reparameterization layer: turns (z_mean, z_log_var) into a sample z.\n",
    "\n",
    "    z is the latent vector encoding a digit. It is computed as\n",
    "    z_mean + exp(0.5 * z_log_var) * noise, where noise is drawn from a\n",
    "    standard normal distribution with one value per (sample, latent dim).\n",
    "    \"\"\"\n",
    "\n",
    "    def call(self, inputs):\n",
    "        mean, log_variance = inputs\n",
    "        mean_shape = tf.shape(mean)\n",
    "        # Noise has the same (batch, dim) layout as the mean.\n",
    "        noise = tf.keras.backend.random_normal(shape=(mean_shape[0], mean_shape[1]))\n",
    "        # exp(0.5 * log(variance)) is the standard deviation.\n",
    "        std_dev = tf.exp(0.5 * log_variance)\n",
    "        return mean + std_dev * noise\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Build the encoder\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "# The plotting cells further down build 2D latent points, so keep this at 2.\n",
    "latent_dim = 2\n",
    "\n",
    "# Convolutional trunk: two stride-2 convolutions, then a small dense layer.\n",
    "encoder_inputs = keras.Input(shape=(28, 28, 1))\n",
    "conv_features = layers.Conv2D(\n",
    "    filters=32, kernel_size=3, strides=2, padding=\"same\", activation=\"relu\"\n",
    ")(encoder_inputs)\n",
    "conv_features = layers.Conv2D(\n",
    "    filters=64, kernel_size=3, strides=2, padding=\"same\", activation=\"relu\"\n",
    ")(conv_features)\n",
    "flat_features = layers.Flatten()(conv_features)\n",
    "hidden = layers.Dense(units=16, activation=\"relu\")(flat_features)\n",
    "\n",
    "# Two linear heads parameterize the approximate posterior; z is sampled from it.\n",
    "z_mean = layers.Dense(units=latent_dim, name=\"z_mean\")(hidden)\n",
    "z_log_var = layers.Dense(units=latent_dim, name=\"z_log_var\")(hidden)\n",
    "z = Sampling()([z_mean, z_log_var])\n",
    "\n",
    "encoder = keras.Model(\n",
    "    inputs=encoder_inputs, outputs=[z_mean, z_log_var, z], name=\"encoder\"\n",
    ")\n",
    "encoder.summary()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Build the decoder\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "# Project the latent vector to a 7x7x64 feature map, then upsample twice\n",
    "# with stride-2 transposed convolutions.\n",
    "latent_inputs = keras.Input(shape=(latent_dim,))\n",
    "projected = layers.Dense(units=7 * 7 * 64, activation=\"relu\")(latent_inputs)\n",
    "feature_map = layers.Reshape(target_shape=(7, 7, 64))(projected)\n",
    "feature_map = layers.Conv2DTranspose(\n",
    "    filters=64, kernel_size=3, strides=2, padding=\"same\", activation=\"relu\"\n",
    ")(feature_map)\n",
    "feature_map = layers.Conv2DTranspose(\n",
    "    filters=32, kernel_size=3, strides=2, padding=\"same\", activation=\"relu\"\n",
    ")(feature_map)\n",
    "# Single-channel sigmoid output: one value in [0, 1] per pixel.\n",
    "decoder_outputs = layers.Conv2DTranspose(\n",
    "    filters=1, kernel_size=3, padding=\"same\", activation=\"sigmoid\"\n",
    ")(feature_map)\n",
    "\n",
    "decoder = keras.Model(inputs=latent_inputs, outputs=decoder_outputs, name=\"decoder\")\n",
    "decoder.summary()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Define the VAE as a `Model` with a custom `train_step`\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "\n",
    "class VAE(keras.Model):\n",
    "    \"\"\"Variational autoencoder trained end to end through a custom `train_step`.\n",
    "\n",
    "    Args:\n",
    "        encoder: Model that maps a batch of images to `(z_mean, z_log_var, z)`.\n",
    "        decoder: Model that maps latent vectors `z` back to images.\n",
    "        **kwargs: Forwarded to `keras.Model`.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, encoder, decoder, **kwargs):\n",
    "        super().__init__(**kwargs)\n",
    "        self.encoder = encoder\n",
    "        self.decoder = decoder\n",
    "\n",
    "    def train_step(self, data):\n",
    "        \"\"\"Run one optimization step on a batch and report the losses.\n",
    "\n",
    "        Args:\n",
    "            data: A batch of images, or a tuple whose first element is that\n",
    "                batch (any further elements are ignored). The loss reduction\n",
    "                assumes a `(batch, height, width, channels)` layout.\n",
    "\n",
    "        Returns:\n",
    "            A dict with the keys `loss`, `reconstruction_loss` and `kl_loss`.\n",
    "        \"\"\"\n",
    "        if isinstance(data, tuple):\n",
    "            data = data[0]\n",
    "        with tf.GradientTape() as tape:\n",
    "            # Use the sub-models owned by this instance rather than whatever\n",
    "            # `encoder` / `decoder` happen to be bound to in the notebook.\n",
    "            z_mean, z_log_var, z = self.encoder(data)\n",
    "            reconstruction = self.decoder(z)\n",
    "            # binary_crossentropy averages over the channel axis; summing the\n",
    "            # spatial axes gives a per-image loss for any image size (this is\n",
    "            # the former `mean * 28 * 28` for 28x28 inputs).\n",
    "            reconstruction_loss = tf.reduce_mean(\n",
    "                tf.reduce_sum(\n",
    "                    keras.losses.binary_crossentropy(data, reconstruction),\n",
    "                    axis=(1, 2),\n",
    "                )\n",
    "            )\n",
    "            # KL(q(z|x) || N(0, I)): sum over the latent dimensions, then\n",
    "            # average over the batch. Averaging over the latent dimensions\n",
    "            # as well would under-weight this term by a factor of latent_dim.\n",
    "            kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))\n",
    "            kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))\n",
    "            total_loss = reconstruction_loss + kl_loss\n",
    "        grads = tape.gradient(total_loss, self.trainable_weights)\n",
    "        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))\n",
    "        return {\n",
    "            \"loss\": total_loss,\n",
    "            \"reconstruction_loss\": reconstruction_loss,\n",
    "            \"kl_loss\": kl_loss,\n",
    "        }\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Train the VAE\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "# Training is unsupervised, so the labels are dropped and both splits are merged.\n",
    "(x_train, _), (x_test, _) = keras.datasets.mnist.load_data()\n",
    "mnist_digits = np.concatenate((x_train, x_test))\n",
    "# Append a channel axis and rescale the pixel values by 1/255.\n",
    "mnist_digits = mnist_digits[..., np.newaxis].astype(\"float32\") / 255\n",
    "\n",
    "vae = VAE(encoder, decoder)\n",
    "vae.compile(optimizer=keras.optimizers.Adam())\n",
    "vae.fit(mnist_digits, epochs=30, batch_size=128)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Display a grid of sampled digits\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "\n",
    "def plot_latent(encoder, decoder):\n",
    "    \"\"\"Display an n x n grid of digits decoded from a regular 2D latent grid.\n",
    "\n",
    "    Args:\n",
    "        encoder: Not used; kept so existing calls keep working.\n",
    "        decoder: Model whose `predict` maps an array of 2D latent points to\n",
    "            images that reshape to `digit_size x digit_size`.\n",
    "    \"\"\"\n",
    "    n = 30\n",
    "    digit_size = 28\n",
    "    scale = 2.0\n",
    "    figsize = 15\n",
    "    figure = np.zeros((digit_size * n, digit_size * n))\n",
    "    # linearly spaced coordinates corresponding to the 2D plot\n",
    "    # of digit classes in the latent space\n",
    "    grid_x = np.linspace(-scale, scale, n)\n",
    "    grid_y = np.linspace(-scale, scale, n)[::-1]\n",
    "\n",
    "    # Decode the whole grid with a single predict call instead of one call\n",
    "    # per point. Points are ordered row by row: grid_y outer, grid_x inner.\n",
    "    z_grid = np.array([[xi, yi] for yi in grid_y for xi in grid_x])\n",
    "    digits = decoder.predict(z_grid).reshape(n, n, digit_size, digit_size)\n",
    "\n",
    "    for i in range(n):\n",
    "        for j in range(n):\n",
    "            figure[\n",
    "                i * digit_size : (i + 1) * digit_size,\n",
    "                j * digit_size : (j + 1) * digit_size,\n",
    "            ] = digits[i, j]\n",
    "\n",
    "    plt.figure(figsize=(figsize, figsize))\n",
    "    # One tick at the centre of every tile: exactly n positions, so the\n",
    "    # number of tick locations matches the n labels passed below.\n",
    "    start_range = digit_size // 2\n",
    "    end_range = n * digit_size + start_range\n",
    "    pixel_range = np.arange(start_range, end_range, digit_size)\n",
    "    sample_range_x = np.round(grid_x, 1)\n",
    "    sample_range_y = np.round(grid_y, 1)\n",
    "    plt.xticks(pixel_range, sample_range_x)\n",
    "    plt.yticks(pixel_range, sample_range_y)\n",
    "    plt.xlabel(\"z[0]\")\n",
    "    plt.ylabel(\"z[1]\")\n",
    "    plt.imshow(figure, cmap=\"Greys_r\")\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "plot_latent(encoder, decoder)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Display how the latent space clusters different digit classes\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "\n",
    "def plot_label_clusters(encoder, decoder, data, labels):\n",
    "    \"\"\"Scatter the encoder's z_mean for `data`, coloured by `labels`.\n",
    "\n",
    "    Args:\n",
    "        encoder: Model whose `predict` returns three outputs, the first\n",
    "            being z_mean with at least two columns.\n",
    "        decoder: Not used by this function.\n",
    "        data: Inputs passed to `encoder.predict`.\n",
    "        labels: Values used to colour the points, one per input.\n",
    "    \"\"\"\n",
    "    latent_means = encoder.predict(data)[0]\n",
    "    fig, ax = plt.subplots(figsize=(12, 10))\n",
    "    points = ax.scatter(latent_means[:, 0], latent_means[:, 1], c=labels)\n",
    "    fig.colorbar(points, ax=ax)\n",
    "    ax.set_xlabel(\"z[0]\")\n",
    "    ax.set_ylabel(\"z[1]\")\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "# Reload the training split, this time keeping the labels for colouring.\n",
    "(x_train, y_train), _ = keras.datasets.mnist.load_data()\n",
    "x_train = x_train[..., np.newaxis].astype(\"float32\") / 255\n",
    "\n",
    "plot_label_clusters(encoder, decoder, x_train, y_train)\n"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "vae",
   "private_outputs": false,
   "provenance": [],
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}